{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments on pooling\n",
    "* non-uniform sampling scheme\n",
    "* random part of sphere"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "sys.path.append(\"..\")\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import healpy as hp\n",
    "from tqdm import tqdm\n",
    "\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import cartopy.crs as ccrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative directory for figure output; not read by any cell visible in this notebook.\n",
    "pathfig = './figures/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Non-uniform sampling scheme"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using GHCN data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN.GHCN_preprocessing import get_data, get_stations, sphereGraph\n",
    "\n",
    "# NOTE(review): absolute, machine-specific path -- consider making it configurable.\n",
    "datapath = \"/mnt/nas/LTS2/datasets/ghcn-daily/processed/\"\n",
    "# np.arange excludes the upper bound: years 2010 through 2014.\n",
    "years = np.arange(2010,2015)\n",
    "features = [\"TMIN\"]\n",
    "\n",
    "# Station metadata; the last two returned values are discarded.\n",
    "n_stations, ghcn_to_local, lat, lon, _, _ = get_stations(datapath, years)\n",
    "# presumably `data` is indexed (station, day, feature), judging by how later cells index it -- TODO confirm\n",
    "data, n_days = get_data(datapath, years, features, ghcn_to_local)\n",
    "\n",
    "print(f'n_stations: {n_stations}, n_days: {n_days}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap the first two axes of `data`, then keep only the entries along axis 1\n",
    "# that contain no NaN for any index of the other two axes.\n",
    "dataset = np.transpose(data, axes=(1, 0, 2))\n",
    "keepToo = np.logical_not(np.isnan(dataset).any(axis=0)).all(axis=1)\n",
    "dataset = dataset[:, keepToo, :]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Original signal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour range; values outside it are clipped before plotting.\n",
    "zmin, zmax = -20, 40\n",
    "\n",
    "fig = plt.figure(figsize=(25, 25))\n",
    "ax = fig.add_subplot(\n",
    "    1, 1, 1,\n",
    "    projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90),\n",
    ")\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "\n",
    "# Scatter the kept stations, coloured by data[keepToo, 0, 0].\n",
    "sc = ax.scatter(\n",
    "    lon[keepToo], lat[keepToo],\n",
    "    c=np.clip(data[keepToo, 0, 0], zmin, zmax),\n",
    "    s=10, alpha=1,\n",
    "    cmap=plt.get_cmap('RdYlBu_r'), vmin=zmin, vmax=zmax,\n",
    "    transform=ccrs.PlateCarree(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using spectral clustering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build a KNN graph, cluster its nodes spectrally, and merge the nodes of each cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pooling factor: the number of clusters is the number of kept stations divided by k.\n",
    "k = 5\n",
    "NCluster = dataset.shape[1]//k\n",
    "g_1 = sphereGraph(lon[keepToo], lat[keepToo], 10, rad=False, epsilon=False)\n",
    "g_1.plot()\n",
    "g_1.compute_laplacian('combinatorial')\n",
    "# The clustering cell skips column 0 and slices U[:, 1:NCluster+1], so NCluster+1\n",
    "# eigenvectors are required for that slice to really hold NCluster columns.\n",
    "g_1.compute_fourier_basis(n_eigenvectors=NCluster+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "# Spectral embedding: skip the first eigenvector and cluster on the following ones.\n",
    "eig_vectors = g_1.U[:,1:NCluster+1]\n",
    "# Fixed seed so that the cluster labels, and every figure derived from them,\n",
    "# are the same after a kernel restart.\n",
    "clusters = KMeans(n_clusters=NCluster, random_state=0).fit_predict(eig_vectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the signal inside every spectral cluster and place the pooled node at the\n",
    "# centroid of its member stations.\n",
    "pool = 'max'\n",
    "pool_fun = getattr(np, pool)\n",
    "\n",
    "# Hoisted out of the loop: the kept coordinates are the same for every cluster.\n",
    "lon_rad, lat_rad = np.deg2rad(lon[keepToo]), np.deg2rad(lat[keepToo])\n",
    "# Unit vectors of the stations. Averaging these instead of raw degrees keeps a\n",
    "# cluster that straddles the +/-180 degree meridian from ending up near longitude 0.\n",
    "xyz = np.stack([np.cos(lat_rad) * np.cos(lon_rad),\n",
    "                np.cos(lat_rad) * np.sin(lon_rad),\n",
    "                np.sin(lat_rad)], axis=1)\n",
    "\n",
    "clust_lon, clust_lat = np.empty(NCluster), np.empty(NCluster)\n",
    "# Same shape as `dataset`, with axis 1 shrunk to one entry per cluster.\n",
    "size = dataset.shape[:1] + (NCluster,) + dataset.shape[2:]\n",
    "new_map = np.zeros(size)\n",
    "for i in range(NCluster):\n",
    "    indices = np.where(clusters==i)[0]\n",
    "    cx, cy, cz = xyz[indices].mean(axis=0)\n",
    "    # Back to degrees; longitude lies in (-180, 180].\n",
    "    clust_lon[i] = np.rad2deg(np.arctan2(cy, cx))\n",
    "    clust_lat[i] = np.rad2deg(np.arctan2(cz, np.hypot(cx, cy)))\n",
    "    data_p = dataset[:, indices, :]\n",
    "    new_map[:,i,:] = pool_fun(data_p, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour range; values outside it are clipped before plotting.\n",
    "zmin, zmax = -20, 40\n",
    "\n",
    "fig = plt.figure(figsize=(25, 25))\n",
    "ax = fig.add_subplot(\n",
    "    1, 1, 1,\n",
    "    projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90),\n",
    ")\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "\n",
    "# Scatter the first-level cluster positions, coloured by new_map[0, :, 0].\n",
    "sc = ax.scatter(\n",
    "    clust_lon, clust_lat,\n",
    "    c=np.clip(new_map[0,:,0], zmin, zmax),\n",
    "    s=50, alpha=1,\n",
    "    cmap=plt.get_cmap('RdYlBu_r'), vmin=zmin, vmax=zmax,\n",
    "    transform=ccrs.PlateCarree(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second pooling level: build a graph on the first-level cluster positions.\n",
    "NCluster_2 = new_map.shape[1]//k\n",
    "g_2 = sphereGraph(clust_lon, clust_lat, 10, rad=False, epsilon=False)\n",
    "fig = plt.figure(figsize=(25,25))\n",
    "axes = fig.add_subplot(111, projection='3d')\n",
    "g_2.plot(vertex_size=50, ax=axes)\n",
    "g_2.compute_laplacian('combinatorial')\n",
    "# The clustering cell skips column 0 and slices U[:, 1:NCluster_2+1], so NCluster_2+1\n",
    "# eigenvectors are required for that slice to really hold NCluster_2 columns.\n",
    "g_2.compute_fourier_basis(n_eigenvectors=NCluster_2+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spectral embedding of the second-level graph: skip the first eigenvector.\n",
    "eig_vectors2 = g_2.U[:,1:NCluster_2+1]\n",
    "# Fixed seed so that the cluster labels are the same after a kernel restart.\n",
    "clusters2 = KMeans(n_clusters=NCluster_2, random_state=0).fit_predict(eig_vectors2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the first-level map inside every second-level cluster and place the pooled\n",
    "# node at the centroid of its member nodes.\n",
    "pool = 'max'\n",
    "pool_fun = getattr(np, pool)\n",
    "\n",
    "lon_rad, lat_rad = np.deg2rad(clust_lon), np.deg2rad(clust_lat)\n",
    "# Unit vectors of the first-level nodes. Averaging these instead of raw degrees keeps\n",
    "# a cluster that straddles the +/-180 degree meridian from ending up near longitude 0.\n",
    "xyz = np.stack([np.cos(lat_rad) * np.cos(lon_rad),\n",
    "                np.cos(lat_rad) * np.sin(lon_rad),\n",
    "                np.sin(lat_rad)], axis=1)\n",
    "\n",
    "clust_lon2, clust_lat2 = np.empty(NCluster_2), np.empty(NCluster_2)\n",
    "# Same shape as `new_map`, with axis 1 shrunk to one entry per cluster.\n",
    "size = new_map.shape[:1] + (NCluster_2,) + new_map.shape[2:]\n",
    "new_map2 = np.zeros(size)\n",
    "for i in range(NCluster_2):\n",
    "    indices = np.where(clusters2==i)[0]\n",
    "    cx, cy, cz = xyz[indices].mean(axis=0)\n",
    "    # Back to degrees; longitude lies in (-180, 180].\n",
    "    clust_lon2[i] = np.rad2deg(np.arctan2(cy, cx))\n",
    "    clust_lat2[i] = np.rad2deg(np.arctan2(cz, np.hypot(cx, cy)))\n",
    "    data_p = new_map[:, indices, :]\n",
    "    new_map2[:,i,:] = pool_fun(data_p, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour range; values outside it are clipped before plotting.\n",
    "zmin, zmax = -20, 40\n",
    "\n",
    "fig = plt.figure(figsize=(25, 25))\n",
    "ax = fig.add_subplot(\n",
    "    1, 1, 1,\n",
    "    projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90),\n",
    ")\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "\n",
    "# Scatter the second-level cluster positions, coloured by new_map2[0, :, 0].\n",
    "sc = ax.scatter(\n",
    "    clust_lon2, clust_lat2,\n",
    "    c=np.clip(new_map2[0,:,0], zmin, zmax),\n",
    "    s=40, alpha=1,\n",
    "    cmap=plt.get_cmap('RdYlBu_r'), vmin=zmin, vmax=zmax,\n",
    "    transform=ccrs.PlateCarree(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TODO: implement these operations in TensorFlow."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Sparsify the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pygsp as pg\n",
    "from pygsp.reduction import graph_multiresolution, graph_sparsify"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sparsify the first-level graph g_1 with parameter 0.8 (see pygsp.reduction.graph_sparsify).\n",
    "gNew1 = graph_sparsify(g_1, 0.8)\n",
    "# TODO: find a way to obtain the vertex coordinates for the sparsified graph.\n",
    "# NOTE(review): presumably only edges are removed, so g_1's coordinates may still apply -- confirm."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using uniform sampling scheme"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import healpy as hp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pool = 'max' # in ['max', 'average', ...]\n",
    "Nside = 64\n",
    "# Called with lonlat=True, so these hold longitude and latitude despite their names.\n",
    "theta, phi = lon[keepToo], lat[keepToo]\n",
    "# Per the healpy docs this returns the neighbouring pixels of every point and their\n",
    "# interpolation weights; the code below treats axis 1 of `pix` as the station axis.\n",
    "pix, weights = hp.get_interp_weights(Nside, theta, phi, nest=True, lonlat=True)\n",
    "indexes = np.unique(pix)\n",
    "\n",
    "# Same shape as `dataset`, with axis 1 shrunk to one entry per touched pixel.\n",
    "size = dataset.shape[:1] + (len(indexes),) + dataset.shape[2:]\n",
    "new_map = np.zeros(size)\n",
    "pool_fun = getattr(np, pool)\n",
    "for i, index in enumerate(indexes):\n",
    "    pl = np.nonzero(pix == index)\n",
    "    # Inverse weight capped at 1.\n",
    "    # NOTE(review): for weights in [0, 1] this is (almost) always 1 -- confirm the intent.\n",
    "    wght = np.minimum(1/(weights[pl]+1e-8), 1)\n",
    "    data_p = dataset[:, pl[1], :] * wght[np.newaxis, :, np.newaxis]\n",
    "    new_map[:,i,:] = pool_fun(data_p, axis=1)\n",
    "new_lon, new_lat = hp.pix2ang(Nside, indexes, nest=True, lonlat=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour range; values outside it are clipped before plotting.\n",
    "zmin, zmax = -20, 40\n",
    "\n",
    "fig = plt.figure(figsize=(25, 25))\n",
    "ax = fig.add_subplot(\n",
    "    1, 1, 1,\n",
    "    projection=ccrs.Orthographic(central_longitude=-50, central_latitude=90),\n",
    ")\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "\n",
    "# Scatter the touched HEALPix pixels, coloured by new_map[0, :, 0].\n",
    "sc = ax.scatter(\n",
    "    new_lon, new_lat,\n",
    "    c=np.clip(new_map[0, :,0], zmin, zmax),\n",
    "    s=40, alpha=1,\n",
    "    cmap=plt.get_cmap('RdYlBu_r'), vmin=zmin, vmax=zmax,\n",
    "    transform=ccrs.PlateCarree(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part of sphere"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ds_index(index, nsides, nest=True):\n",
    "    \"\"\"Return list of indexes sampled at specific nsides.\n",
    "    \n",
    "    The given index must be sampled at the first nside given\n",
    "    Parameters\n",
    "    ----------\n",
    "    index : list of pixel position for part of sphere\n",
    "    nsides : list of nside for the desired scale\n",
    "    \"\"\"\n",
    "    assert isinstance(nsides, list)\n",
    "    assert len(nsides) > 1\n",
    "    assert nest  # not implemented yet\n",
    "    \n",
    "    indexes = [index]\n",
    "    for nside in nsides[1:]:\n",
    "        p = (nsides[0]/nside)**2\n",
    "        if p < 1:\n",
    "            raise NotImplementedError(\"upsampling not implemented yet\")\n",
    "        temp_index = index//p\n",
    "        indexes.append(np.unique(temp_index).astype(int))            \n",
    "    \n",
    "    return indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pool_part_max(x, p, Nside, index):\n",
    "        \"\"\"Max pooling of size p on partial sphere. Sould be a power of 2.\"\"\"\n",
    "        if p > 1:\n",
    "            indexes = ds_index(index, [Nside, Nside//(p**0.5)])[1]\n",
    "            full_map = tf.ones([x.shape[0], hp.nside2npix(Nside), x.shape[2]]) * -1e8\n",
    "#             j = 0\n",
    "#             full_map = []\n",
    "#             for i in tqdm(range(hp.nside2npix(Nside))):\n",
    "#                 if i in index:\n",
    "#                     full_map.append(x[:,j,:])\n",
    "#                     j += 1\n",
    "#                 else:\n",
    "#                     full_map.append(tf.ones([x.shape[0], x.shape[2]]) * -1e8)\n",
    "            for i, ind in tqdm(enumerate(index)):\n",
    "                new_full_map = tf.Variable(full_map, trainable=False)[:,ind,:].assign(x[:,i,:])\n",
    "#             full_map[:,index.astype(np.int32),:].assign(x[:,:,:])\n",
    "#             full_map = tf.stack(full_map, axis=1)\n",
    "            new_full_map = tf.expand_dims(new_full_map, 3)\n",
    "            new_full_map = tf.nn.max_pool(new_full_map, ksize=[1,p,1,1], strides=[1,p,1,1], padding='SAME')\n",
    "            new_full_map = tf.squeeze(new_full_map, [3])\n",
    "            x_new = tf.gather(full_map, indexes, axis=1)\n",
    "            return x_new\n",
    "        else:\n",
    "            return x, _, _"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This takes too much time.\n",
    "\n",
    "TODO: find a way to make it faster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Graph: a float32 placeholder shaped like `new_map`, pooled with size 4.\n",
    "x = tf.placeholder(tf.float32, shape=new_map.shape, name='input_data')\n",
    "feed_dict = {x: new_map}\n",
    "op_pool = pool_part_max(x, 4, Nside, indexes)\n",
    "\n",
    "# Session with GPU memory allocated on demand.\n",
    "config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "sess = tf.Session(config=config)\n",
    "sess.run(tf.global_variables_initializer())\n",
    "new_map2 = sess.run(op_pool, feed_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colour range; values outside it are clipped before plotting.\n",
    "zmin, zmax = -20, 40\n",
    "\n",
    "# Coordinates of the pooled pixels, one resolution level below Nside.\n",
    "pix_coarse = ds_index(indexes, [Nside, Nside//(4**0.5)])[1]\n",
    "new_lon2, new_lat2 = hp.pix2ang(Nside//2, pix_coarse, nest=True, lonlat=True)\n",
    "\n",
    "# Top panel: map before pooling. Bottom panel: map after pooling.\n",
    "fig = plt.figure(figsize=(25, 50))\n",
    "panels = [(new_lon, new_lat, new_map), (new_lon2, new_lat2, new_map2)]\n",
    "for row, (p_lon, p_lat, p_map) in enumerate(panels, start=1):\n",
    "    ax = fig.add_subplot(2, 1, row, projection=ccrs.Orthographic(-50, 90))\n",
    "    ax.set_global()\n",
    "    ax.coastlines(linewidth=2)\n",
    "    sc = ax.scatter(\n",
    "        p_lon, p_lat,\n",
    "        c=np.clip(p_map[0, :, 0], zmin, zmax),\n",
    "        s=40, alpha=1,\n",
    "        cmap=plt.get_cmap('RdYlBu_r'), vmin=zmin, vmax=zmax,\n",
    "        transform=ccrs.PlateCarree(),\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
